{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import joblib\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "from joblib  import Parallel,delayed\n",
    "from minepy import MINE\n",
    "\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1.展示36个xgb的mic矩阵"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def mic(x,y):\n",
    "    '''\n",
    "    Compute the maximal information coefficient (MIC) of two score series.\n",
    "\n",
    "    Parameters:\n",
    "        x, y: either a pandas DataFrame exposing a `score` column, or a value\n",
    "              that is handed to MINE.compute_score unchanged.\n",
    "    Returns:\n",
    "        The result of MINE.mic() using the 'mic_e' estimator.\n",
    "    '''\n",
    "    # Unwrap each argument on its own so that a DataFrame may be paired with a\n",
    "    # plain array, and DataFrame subclasses are recognised as well\n",
    "    if isinstance(x, pd.DataFrame):\n",
    "        x = x.score.values.ravel()\n",
    "    if isinstance(y, pd.DataFrame):\n",
    "        y = y.score.values.ravel()\n",
    "    m = MINE(est ='mic_e')\n",
    "    m.compute_score(x,y)\n",
    "    return m.mic()\n",
    "\n",
    "def cal_row(x, pred_list, n_jobs=-1):\n",
    "    '''\n",
    "    Compute mic(x, y) for every y in pred_list and return the values as a list.\n",
    "\n",
    "    Parameters:\n",
    "        x: one prediction, in any form accepted by mic.\n",
    "        pred_list: sequence of predictions to compare against x.\n",
    "        n_jobs: number of joblib workers; -1 means one worker per element of\n",
    "                pred_list (note: this differs from joblib's own meaning of -1).\n",
    "    Returns:\n",
    "        list with one MIC value per element of pred_list.\n",
    "    '''\n",
    "    if len(pred_list) == 0:\n",
    "        # Nothing to compare; also keeps n_jobs=0 away from joblib\n",
    "        return []\n",
    "    if n_jobs == -1:\n",
    "        n_jobs = len(pred_list)\n",
    "    # Honour the resolved n_jobs instead of always starting len(pred_list) workers\n",
    "    result = Parallel(n_jobs=n_jobs,verbose=10)(delayed(mic)(x,y) for y in pred_list)\n",
    "    return result\n",
    "\n",
    "def cal_matrix(pred_list, n_jobs = -1):\n",
    "    '''\n",
    "    Build the pairwise MIC matrix of pred_list.\n",
    "\n",
    "    Conceptually two nested loops: the outer one runs here (one cal_row call\n",
    "    per element) and the inner one lives inside cal_row. Both levels go\n",
    "    through joblib, so the computation is heavy and can occupy many processes;\n",
    "    pass a smaller n_jobs to limit that. n_jobs == -1 is replaced by\n",
    "    len(pred_list).\n",
    "\n",
    "    Returns:\n",
    "        np.array built from the per-row results.\n",
    "    '''\n",
    "    workers = len(pred_list) if n_jobs == -1 else n_jobs\n",
    "    row_tasks = (delayed(cal_row)(row, pred_list, workers) for row in pred_list)\n",
    "    rows = Parallel(n_jobs=workers,verbose=10)(row_tasks)\n",
    "    return np.array(rows)\n",
    "\n",
    "def plot_mic_matrix(mic_matrix,ticks):\n",
    "    '''\n",
    "    Draw mic_matrix as an annotated seaborn heat map on a new 20x20 figure.\n",
    "    ticks supplies the labels used on both axes.\n",
    "    '''\n",
    "    plt.figure(figsize=(20,20))\n",
    "    plt.xticks(fontsize=20)\n",
    "    plt.yticks(fontsize=20)\n",
    "    heatmap_options = dict(\n",
    "        linewidths=0.1,\n",
    "        linecolor='white',\n",
    "        vmax=1.0,\n",
    "        square=True,\n",
    "        annot=True,\n",
    "        xticklabels=ticks,\n",
    "        yticklabels=ticks,\n",
    "    )\n",
    "    sns.heatmap(mic_matrix, **heatmap_options)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "# Collect the prediction files, skipping hidden dot-entries\n",
    "file_names = [file for file in os.listdir('./preds/') if not file.startswith('.')]\n",
    "# Sort numerically by the digits that follow the 3-character prefix of each name\n",
    "# (assumes names shaped like <3 chars><number>.<ext>)\n",
    "file_names = sorted(file_names,key=lambda x:int(x.split('.')[0][3:]))\n",
    "pred_list = [pd.read_csv(os.path.join('./preds/',file)) for file in file_names]\n",
    "\n",
    "mic_matrix = cal_matrix(pred_list,n_jobs=4)\n",
    "# joblib.dump needs the joblib module itself (`import joblib` in the import cell);\n",
    "# `from joblib import Parallel,delayed` alone does not bind that name\n",
    "joblib.dump(mic_matrix,'mic_matrix')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Heat map of the pairwise MIC values, labelled with the prediction file names\n",
    "plot_mic_matrix(mic_matrix, ticks=file_names)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2.blending"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# List the prediction files (hidden dot-entries excluded) and sort them numerically\n",
    "file_names = sorted(\n",
    "    (file for file in os.listdir('./preds/') if file[0]!='.'),\n",
    "    key=lambda x:int(x.split('.')[0][3:]),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Based on the heat map, pick predictions with low mutual correlation\n",
    "index = [0,1,3,11,16,17,21,24,26,30,32]\n",
    "selected_preds = [pd.read_csv(os.path.join('./preds/',file_names[idx])) for idx in index]\n",
    "pred = pd.DataFrame()\n",
    "# ids are taken from the first selected file; assumes every file lists the same\n",
    "# ids in the same row order - TODO confirm\n",
    "pred['id'] = selected_preds[0].id\n",
    "# Equal-weight mean: every selected file counts exactly once and the sum is\n",
    "# divided by the number of files, so the result stays on the scale of one score\n",
    "pred['prob'] = np.mean([p.score.values for p in selected_preds],axis=0)\n",
    "pred.to_csv('avg_pred.txt',index=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 线上valid-auc\n",
    "auc: 0.83046852887673"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
